"""Classes to describe individual events in Monte Carlo simulations."""

import logging

from mcmc.events.criterion import AcceptanceCriterion as Criterion
from mcmc.events.proposal import Proposal
from mcmc.slab import change_site
from mcmc.system import SurfaceSystem


class Event:
    """Common base for Monte Carlo events.

    An event bundles a surface system, a proposal describing a candidate change, and a
    criterion that decides whether that change is kept. Subclasses supply ``forward``;
    this base class provides the rollback (``backward``) and the apply/decide/undo
    sequence (``acceptance``).
    """

    def __init__(
        self, system: SurfaceSystem, proposal: Proposal, criterion: Criterion, **kwargs
    ) -> None:
        """Store the collaborators the event needs.

        Args:
            system (SurfaceSystem): Surface system the event will modify.
            proposal (Proposal): Object that generates the proposed change.
            criterion (Criterion): Callable deciding whether the change is kept.
            **kwargs: Extra keyword arguments, stored on the instance as ``kwargs``.

        Attributes:
            system (SurfaceSystem): Surface system the event will modify.
            logger (logging.Logger): ``system.logger`` when it is truthy, otherwise this
                module's logger.
            proposal (Proposal): Object that generates the proposed change.
            criterion (Criterion): Callable deciding whether the change is kept.
            kwargs (dict): Extra keyword arguments given at construction.
        """
        self.system = system
        system_logger = system.logger
        # Only fall back to the module logger when the system does not carry a usable one.
        self.logger = system_logger if system_logger else logging.getLogger(__name__)
        self.proposal, self.criterion = proposal, criterion
        self.kwargs = kwargs

    def forward(self) -> None:
        """Apply the proposed change to the system.

        Raises:
            NotImplementedError: Always; concrete events must override this method.
        """
        raise NotImplementedError

    def backward(self) -> None:
        """Undo the proposed change by reloading the state saved under "before"."""
        self.system.restore_state("before")

    def acceptance(self, **kwargs) -> tuple[bool, SurfaceSystem]:
        """Apply the event, then keep it or roll it back according to the criterion.

        Args:
            **kwargs: Passed through unchanged to the criterion call.

        Returns:
            tuple[bool, SurfaceSystem]: The criterion's verdict, and the system as it
                stands afterwards (after ``backward`` has run if the verdict is falsy).
        """
        self.forward()
        accepted = self.criterion(self.system, **kwargs)
        if accepted:
            self.logger.debug("state changed!")
        else:
            # Rejected: restore the snapshot taken before the change, then report.
            self.backward()
            self.logger.debug("state not changed!")
        return accepted, self.system


class Change(Event):
    """Semigrand canonical Monte Carlo event that replaces the adsorbate on one site."""

    def __init__(
        self, system: SurfaceSystem, proposal: Proposal, criterion: Criterion, **kwargs
    ) -> None:
        """Set up the event and draw its action from the proposal.

        Args:
            system (SurfaceSystem): Surface system the event will modify.
            proposal (Proposal): Object whose ``get_action()`` supplies the change.
            criterion (Criterion): Callable deciding whether the change is kept.
            **kwargs: Extra keyword arguments forwarded to ``Event.__init__``.

        Attributes:
            action (dict): Action returned by ``proposal.get_action()``.
            site_idx (int): Index of the site to modify.
            start_ads (str): Adsorbate on that site before the change.
            end_ads (str): Adsorbate to place on that site.
        """
        super().__init__(system, proposal, criterion, **kwargs)
        action = self.proposal.get_action()
        self.action = action
        self.site_idx, self.start_ads, self.end_ads = (
            action["site_idx"],
            action["start_ads"],
            action["end_ads"],
        )

    def forward(self) -> None:
        """Snapshot the system, put ``end_ads`` on ``site_idx``, and snapshot again.

        The pre-change snapshot is saved under "before" (the key ``backward`` restores)
        and the post-change snapshot under "after".
        """
        self.system.save_state("before")
        updated = change_site(self.system, self.site_idx, self.end_ads)
        self.system = updated
        updated.save_state("after")
        self.logger.debug("after proposed state is")
        self.logger.debug(updated.occ)


class Exchange(Event):
    """Canonical Monte Carlo event that swaps the adsorbates found at two sites."""

    def __init__(
        self, system: SurfaceSystem, proposal: Proposal, criterion: Criterion, **kwargs
    ) -> None:
        """Set up the event and draw its two-site action from the proposal.

        Args:
            system (SurfaceSystem): Surface system the event will modify.
            proposal (Proposal): Object whose ``get_action()`` supplies the swap.
            criterion (Criterion): Callable deciding whether the change is kept.
            **kwargs: Extra keyword arguments forwarded to ``Event.__init__``.

        Attributes:
            action (dict): Action returned by ``proposal.get_action()``.
            site1_idx (int): Index of the first site.
            site2_idx (int): Index of the second site.
            site1_ads (str): Adsorbate on the first site before the swap.
            site2_ads (str): Adsorbate on the second site before the swap.
        """
        super().__init__(system, proposal, criterion, **kwargs)
        self.action = action = self.proposal.get_action()
        self.site1_idx, self.site2_idx = action["site1_idx"], action["site2_idx"]
        self.site1_ads, self.site2_ads = action["site1_ads"], action["site2_ads"]

    def forward(self) -> None:
        """Snapshot the system, swap the two sites' adsorbates, and snapshot again.

        The swap is carried out as two single-site writes, in order: ``site2_ads`` onto
        ``site1_idx``, then ``site1_ads`` onto ``site2_idx``. Snapshots are saved under
        "before" and "after". No atom-count conservation check is performed here.
        """
        self.system.save_state("before")
        # Each entry is (site to overwrite, adsorbate to place there); order matters.
        writes = (
            (self.site1_idx, self.site2_ads),
            (self.site2_idx, self.site1_ads),
        )
        for target_idx, new_ads in writes:
            self.system = change_site(self.system, target_idx, new_ads)
        self.system.save_state("after")
        self.logger.debug("after proposed state is")
        self.logger.debug(self.system.occ)
